package xredis

import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"github.com/redis/go-redis/v9"
	"math"
	"time"
)

// DistributedLock is a Redis-backed mutual-exclusion lock on a single key.
//
// Ownership is tracked with a per-handle random token: TryLock stores it as
// the key's value via SetNX, and Unlock only deletes the key when it still
// holds that same token, so one holder cannot release another holder's lock.
// Instances are created with NewDistributedLock.
type DistributedLock struct {
	client     *redis.Client // Redis client used for every lock command
	key        string        // Redis key that represents the lock
	value      string        // this handle's unique owner token (guards against deleting someone else's lock)
	expiration time.Duration // expiration passed to SetNX when the lock is acquired
	retries    int           // bound on the number of SetNX attempts made by Lock
	retryWait  time.Duration // base delay for Lock's exponential backoff (doubled after each failed attempt)
}

// NewDistributedLock creates a lock handle for key. It does not contact Redis.
//
// Each handle gets its own random 128-bit owner token (base64-encoded). The
// token is written as the key's value when the lock is acquired and checked
// when it is released.
//
// Parameters:
//   - client: Redis client used by TryLock, Lock and Unlock.
//   - key: Redis key that represents the lock.
//   - expiration: expiration passed to SetNX on acquisition.
//   - retries: attempt budget for Lock.
//   - retryWait: base delay for Lock's exponential backoff.
//
// It panics if the system's secure random source fails. Continuing with a
// zero-filled buffer would give every handle the same token and let any
// handle release locks held by the others.
func NewDistributedLock(client *redis.Client, key string, expiration time.Duration, retries int, retryWait time.Duration) *DistributedLock {
	buf := make([]byte, 16)
	if _, err := rand.Read(buf); err != nil {
		panic(fmt.Sprintf("xredis: generating lock owner token: %v", err))
	}
	return &DistributedLock{
		client:     client,
		key:        key,
		value:      base64.StdEncoding.EncodeToString(buf),
		expiration: expiration,
		retries:    retries,
		retryWait:  retryWait,
	}
}

// Lock acquires the lock, retrying with exponential backoff while it is held
// by someone else.
//
// Lock makes up to dl.retries attempts, and always at least one, even when
// retries <= 0. Between attempts it waits retryWait, 2*retryWait,
// 4*retryWait, and so on. The delay saturates at the maximum time.Duration
// instead of overflowing. Lock does not wait after the final attempt.
//
// Return values:
//   - nil once the lock is acquired.
//   - The TryLock error as soon as a Redis call fails; there is no retry.
//   - ctx.Err() if ctx is cancelled while waiting between attempts.
//   - An error if every attempt found the lock already taken.
func (dl *DistributedLock) Lock(ctx context.Context) error {
	// A non-positive budget used to skip the loop entirely and report failure
	// without ever contacting Redis; always make at least one attempt.
	attempts := dl.retries
	if attempts < 1 {
		attempts = 1
	}

	wait := dl.retryWait
	for attempt := 1; ; attempt++ {
		ok, err := dl.TryLock(ctx)
		if err != nil {
			return err
		}
		if ok {
			return nil
		}
		if attempt >= attempts {
			// Give up right away: sleeping here would only delay the error.
			return errors.New("maximum retries exceeded")
		}

		// An explicit timer (rather than time.After) is stopped as soon as
		// the context is cancelled.
		timer := time.NewTimer(wait)
		select {
		case <-timer.C:
		case <-ctx.Done():
			timer.Stop()
			return ctx.Err()
		}

		// Double the delay. Saturate instead of wrapping around int64, because
		// a wrapped value would be a negative (i.e. zero) or meaningless wait.
		switch {
		case wait <= 0:
			// Nothing to grow; remaining attempts run back to back.
		case wait > time.Duration(math.MaxInt64/2):
			wait = time.Duration(math.MaxInt64)
		default:
			wait *= 2
		}
	}
}

// TryLock makes a single, non-blocking attempt to acquire the lock.
//
// It calls SetNX on the lock key with this handle's owner token and the
// configured expiration. It reports true if this handle now holds the lock
// and false if the key already exists.
//
// Redis or transport failures are returned wrapped with %w, so callers can
// inspect them with errors.Is or errors.As (for example context.Canceled or
// context.DeadlineExceeded).
//
// NOTE(review): in go-redis an expiration of 0 stores the key without a TTL,
// so a crashed holder would keep the lock forever. Confirm that callers
// always pass a positive expiration.
func (dl *DistributedLock) TryLock(ctx context.Context) (bool, error) {
	acquired, err := dl.client.SetNX(ctx, dl.key, dl.value, dl.expiration).Result()
	if err != nil {
		return false, fmt.Errorf("redis setnx error: %w", err)
	}
	return acquired, nil
}

// Unlock releases the lock, but only if it is still held by this handle.
//
// The ownership check and the delete run inside one Lua script, which Redis
// executes atomically. A separate GET followed by DEL could delete a lock
// that expired and was re-acquired by another holder in between.
//
// It returns an error wrapping the Redis failure (%w) if the script cannot
// be run. It also returns an error if the key no longer exists or holds a
// different token, meaning the lock expired or was taken over.
func (dl *DistributedLock) Unlock(ctx context.Context) error {
	// KEYS[1] is the lock key and ARGV[1] is the owner token. Lua tables are
	// 1-indexed, and Redis rejects a bare table (KEYS/ARGV) as a command
	// argument, which previously made every Unlock fail.
	const script = `
if redis.call("get", KEYS[1]) == ARGV[1] then
	return redis.call("del", KEYS[1])
else
	return 0
end
`
	deleted, err := dl.client.Eval(ctx, script, []string{dl.key}, dl.value).Int()
	if err != nil {
		return fmt.Errorf("redis eval error: %w", err)
	}
	if deleted == 0 {
		return errors.New("lock not exists or value mismatch")
	}
	return nil
}
